% Configuration for the out-of-sample (leave-one-trial-out) mixture estimation.
% Everything below only defines workspace variables; the settings are packed
% into the vector `sts` and the per-type starting values into `starting_values`.
clear

%set a name for the estimation output (used as file name for the results)
estimation_name='abc_2t_rn';

%%%%%%%%%%%%%%%%%%%%%%%%
%estimation settings
%%%%%%%%%%%%%%%%%%%%%%%%

%set number of types (1, 2 or 3 are supported by the starting values below)
n_types=2;

%set risk neutral or not
risk_neutral=1;         %if set to zero, takes r=1 for crra utility
                        %if set to one, assumes risk neutrality
                        %if set to two, estimates r as a free parameter

%set constraints
%NOTE(review): presumably a flag value of 1 activates the constraint and 0
%leaves the parameter free -- confirm against hm_mix_est.
alpha_zero=0;           %constrain alpha to zero
beta_zero=0;            %constrain beta to zero
kappa_zero=0;           %constrain kappa to zero
delta_zero=1;           %constrain delta to zero
gamma_zero=1;           %constrain gamma to zero
pms_zero = [alpha_zero beta_zero kappa_zero delta_zero gamma_zero];

%set starting values (one row vector of six entries per type)
starting_values_type1 = [0.12 0.37 0.10 0.00 0.00 8.45];
starting_values_type2 = [0.18 0.01 0.10 0.00 0.00 4.25];
starting_values_type3 = [0.00 0.00 0.00 0.00 0.00 0.00];

%settings for the EM-algorithm
n_steps = 1;            % number of starting points used
n_iterations = 50;      % maximal number of iterations
ll_diff_stop = .01;     % difference in log-likelihood for convergence

%pack all settings into one row vector
sts = [n_types risk_neutral pms_zero n_steps n_iterations ll_diff_stop];

%concatenate the starting values of the types that are actually used
switch n_types
    case 1
        starting_values = starting_values_type1;
    case 2
        starting_values = [starting_values_type1 starting_values_type2];
    case 3
        starting_values = [starting_values_type1 starting_values_type2 starting_values_type3];
    otherwise
        %fail here with a clear message instead of leaving starting_values
        %undefined and erroring later inside the estimation loop
        error('estimation:unsupportedTypes', ...
              'n_types must be 1, 2 or 3 (got %g); add starting values for more types.', n_types);
end

%import choices from experiment
filename='mnl_data_core.csv';
dt=readmatrix(filename);

[m_dt,n_dt] = size(dt);
no_trials=18;

%every subject must contribute exactly one row per trial; check this up
%front so a malformed file fails with a readable message
if mod(m_dt,no_trials) ~= 0
    error('estimation:badData', ...
          '%s has %d rows, which is not a multiple of no_trials = %d.', ...
          filename, m_dt, no_trials);
end

%n is the number of subjects in the data. Use the row count explicitly:
%length(dt) returns the LARGEST dimension and would be wrong for a matrix
%with more columns than rows.
n = m_dt/no_trials;

% cut up the data: page i holds the split in which trial i is held out
est_data = zeros(n*(no_trials-1),n_dt,no_trials);   % rows of all other trials
hold_out_data = zeros(n,n_dt,no_trials);            % rows of trial i only

for i = 1:no_trials

    %column 2 of dt is compared against the trial number
    ind1 = dt(:,2) == i;
    ind2 = dt(:,2) ~= i;

    hold_out_data(:,:,i) = dt(ind1,:);
    est_data(:,:,i) = dt(ind2,:);
end

% Leave-one-trial-out loop: for every trial t the model is estimated on all
% other trials and then used to predict each subject's choice in trial t.
% Outputs (row order follows trial, then subject):
%   Results      - stacked estimates returned by hm_mix_est, one block per trial
%   Results2     - stacked taus returned by hm_mix_est, one block per trial
%   choice_probs - stacked outputs of hm_probs, one per held-out observation
%   predictions  - one 31-column row per held-out observation
Results = [];
Results2 = [];
choice_probs = [];
predictions = [];

%stride of one type's block inside the estimates vector
if risk_neutral==2
    n_params = 7;
else
    n_params = 6;
end

parfor t = 1:no_trials

    %estimate on all trials except t
    est_dt = est_data(:,:,t);
    [taus, estimates] = hm_mix_est(est_dt,sts,starting_values);
    Results = [Results;estimates];
    Results2 = [Results2;taus];

    %collect this trial's output locally and append it once after the inner
    %loop, instead of growing the big arrays row by row
    hold_t = hold_out_data(:,:,t);
    trial_probs = cell(n,1);
    trial_preds = zeros(n,31);

    %predict individual choices
    for i = 1:n
        sub1 = hold_t(i,:);
        id = sub1(1);
        hold_out = [sub1(3:11),0,id,sub1(2)];

        %assign the subject to the type with the largest entry in its taus row
        %NOTE(review): assumes row i of taus belongs to the same subject as
        %row i of the hold-out data -- confirm against hm_mix_est.
        [~,type] = max(taus(i,:));
        offset = n_params*(type-1);

        pms = zeros(1,7);
        pms(1:6) = estimates(offset+1:offset+6);

        if risk_neutral==2
            %read r from the block of the ASSIGNED type (the original always
            %used estimates(7), i.e. the first type's block).
            %NOTE(review): assumes each type's block of n_params entries
            %stores r in slot 7, as the stride above implies -- confirm
            %against hm_mix_est.
            pms(7) = estimates(offset+7);       % risk parameter
        elseif risk_neutral==0
            pms(7) = 1;                         % risk parameter
        else
            pms(7) = 0;                         % risk parameter
        end

        probs = hm_probs(pms,hold_out,2);       % 2 implies using the risk parameter in pms(7)
        trial_probs{i} = probs;

        choice = hold_out(8);
        [~,I] = max(probs);                     % index of the most likely option

        preds = zeros(1,31);
        preds(1:12) = hold_out;
        preds(13:17) = pms(1:5);                % estimated alpha, beta, kappa, delta, gamma
        preds(18) = pms(7);                     % estimated risk parameter
        preds(19) = pms(6);                     % estimated lambda
        preds(20) = 0;                          % (LL in case of individual estimations)
        preds(21:28) = probs;
        preds(29) = probs(choice);              % probability assigned to the observed choice
        preds(30) = I;                          % predicted choice
        preds(31) = (I == choice);              % 1 if the prediction is correct, else 0

        trial_preds(i,:) = preds;
    end

    %one concatenation per trial; row order is identical to appending inside
    %the inner loop
    choice_probs = [choice_probs;vertcat(trial_probs{:})];
    predictions = [predictions;trial_preds];

end

%make sure the output folders exist, so that the results of the (long)
%estimation are not lost to a missing-directory error
if ~exist(fullfile('output','out_of_sample','mat_files'),'dir')
    mkdir(fullfile('output','out_of_sample','mat_files'));
end
if ~exist(fullfile('output','out_of_sample','predictions'),'dir')
    mkdir(fullfile('output','out_of_sample','predictions'));
end

%store the complete workspace
save(fullfile('output','out_of_sample','mat_files',estimation_name));

%export the predictions as a table with one named column per entry of a
%prediction row
T_names = {'T','R','P','S','y1','y2','y3','choice','game_type','lot','id','game_no','a_est','b_est','c_est','d_est','e_est','r_est','l_est','ll','p1','p2','p3','p4','p5','p6','p7','p8','p_correct','pred_choice','choice_correct'};
T = array2table(predictions,'VariableNames',T_names);
writetable(T,fullfile('output','out_of_sample','predictions',estimation_name));
